{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# module load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te\n",
    "from tvm.contrib import cc, utils, popen_pool\n",
    "import sys\n",
    "import numpy as np\n",
    "import subprocess\n",
    "import tvm.testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source of a standalone checker script. A later cell writes it to runtime.py and runs it\n",
    "# in a separate Python process, passing the shared-library path as argv[1] and the dtype\n",
    "# as argv[2]; the script loads the module, calls it on a zero array of length 10 and\n",
    "# asserts that the result equals np.arange of that length.\n",
    "# NOTE(review): TVM_USE_RUNTIME_LIB is set before `import tvm`, presumably to select the\n",
    "# runtime-only library -- confirm against the TVM documentation.\n",
    "runtime_py = \"\"\"\n",
    "import os\n",
    "import sys\n",
    "\n",
    "os.environ[\"TVM_USE_RUNTIME_LIB\"] = \"1\"\n",
    "import tvm\n",
    "from tvm import te\n",
    "import numpy as np\n",
    "path_dso = sys.argv[1]\n",
    "dtype = sys.argv[2]\n",
    "ff = tvm.runtime.load_module(path_dso)\n",
    "a = tvm.nd.array(np.zeros(10, dtype=dtype))\n",
    "ff(a)\n",
    "np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))\n",
    "print(\"Finish runtime checking...\")\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 动态模块加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "dtype = \"int64\"\n",
    "# mkdtemp() raises FileNotFoundError when its parent directory is missing, so create\n",
    "# `.temp` first; this keeps Restart & Run All working on a fresh checkout.\n",
    "Path(\".temp\").mkdir(parents=True, exist_ok=True)\n",
    "# Scratch directory that receives every artifact written by the cells below.\n",
    "temp_dir = Path(tempfile.mkdtemp(dir=\".temp\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_object(names, target=\"llvm\", dtype=\"int64\"):\n",
    "    \"\"\"Build a small TIR kernel and save the built module to every path in ``names``.\n",
    "\n",
    "    The kernel takes a single 1-D buffer of symbolic length ``n`` and, inside a\n",
    "    serial loop, stores ``buffer[i] + 1`` into ``buffer[i + 1]``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    names : iterable of str\n",
    "        Output paths; ``save`` is called on the built module once per path.\n",
    "    target : str\n",
    "        Target passed to ``tvm.tir.build``.\n",
    "    dtype : str\n",
    "        Element type of the buffer.\n",
    "    \"\"\"\n",
    "    length = te.size_var(\"n\")\n",
    "    buf = tvm.tir.decl_buffer((length,), dtype)\n",
    "    idx = te.var(\"i\")\n",
    "    # Loop body: write the current element plus one into the next slot.\n",
    "    store = tvm.tir.BufferStore(buf, tvm.tir.BufferLoad(buf, [idx]) + 1, [idx + 1])\n",
    "    # for i in 0 to n-1 (0 and n - 1 are the loop's second and third arguments)\n",
    "    loop = tvm.tir.For(idx, 0, length - 1, tvm.tir.ForKind.SERIAL, store)\n",
    "    prim_func = tvm.tir.PrimFunc([buf], loop).with_attr(\"global_symbol\", \"main\")\n",
    "    built = tvm.tir.build(tvm.IRModule.from_expr(prim_func), target=target)\n",
    "    for out_path in names:\n",
    "        built.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_obj = str(temp_dir/\"test.o\")\n",
    "path_ll = str(temp_dir/\"test.ll\")\n",
    "path_bc = str(temp_dir/\"test.bc\")\n",
    "path_dso = str(temp_dir/\"test.so\")\n",
    "\n",
    "# The checker script does not depend on the target, so write it once up front\n",
    "# instead of rewriting the same file on every loop iteration.\n",
    "path_runtime_py = temp_dir/\"runtime.py\"\n",
    "with open(path_runtime_py, \"w\") as fo:\n",
    "    fo.write(runtime_py)\n",
    "\n",
    "targets = [\"llvm\", \"llvm -jit=mcjit\"]\n",
    "for target in targets:\n",
    "    # Save the module as .o, .ll and .bc, then link the object file into a shared library.\n",
    "    save_object([path_obj, path_ll, path_bc], target, dtype)\n",
    "    cc.create_shared(path_dso, [path_obj])\n",
    "    f1 = tvm.runtime.load_module(path_dso)\n",
    "    f2 = tvm.runtime.load_module(path_ll)\n",
    "    # Both loaded modules must turn a zero array into 0, 1, ..., len - 1.\n",
    "    a = tvm.nd.array(np.zeros(10, dtype=dtype))\n",
    "    f1(a)\n",
    "    np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))\n",
    "    a = tvm.nd.array(np.zeros(10, dtype=dtype))\n",
    "    f2(a)\n",
    "    np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))\n",
    "\n",
    "    # Repeat the shared-library check from a fresh Python process.\n",
    "    proc = subprocess.run(\n",
    "        [sys.executable, path_runtime_py, path_dso, dtype],\n",
    "        stdout=subprocess.PIPE,\n",
    "        stderr=subprocess.STDOUT,\n",
    "    )\n",
    "    # Decode the captured bytes so a failure shows readable text rather than a bytes repr.\n",
    "    proc_output = proc.stdout.decode(errors=\"replace\")\n",
    "    assert proc.returncode == 0, f\"{proc.args} exited with {proc.returncode}: {proc_output}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模块 dump"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skip because vulkan is not enabled\n",
      "Skip because vulkan is not enabled\n",
      "Skip because opencl is not enabled\n",
      "Skip because opencl is not enabled\n",
      "Skip because metal is not enabled\n",
      "Skip because metal is not enabled\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Make sure nvcc can be found. Append the CUDA bin directory only when it is not\n",
    "# already on PATH, so re-running this cell does not keep growing the variable, and\n",
    "# join with os.pathsep instead of a hard-coded ':'.\n",
    "cuda_bin = \"/usr/local/cuda/bin\"\n",
    "if cuda_bin not in os.environ[\"PATH\"].split(os.pathsep):\n",
    "    os.environ[\"PATH\"] += os.pathsep + cuda_bin\n",
    "\n",
    "# graph\n",
    "n = tvm.runtime.convert(1024)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name=\"B\")\n",
    "\n",
    "sch = tvm.tir.Schedule(te.create_prim_func([A, B]))\n",
    "# Split the first loop of block \"B\" by num_thread and bind the two resulting loops\n",
    "# to blockIdx.x and threadIdx.x.\n",
    "num_thread = 8\n",
    "bx, tx = sch.split(sch.get_loops(\"B\")[0], factors=[None, num_thread])\n",
    "sch.bind(bx, \"blockIdx.x\")\n",
    "sch.bind(tx, \"threadIdx.x\")\n",
    "\n",
    "def check_device(device):\n",
    "    \"\"\"Compile the scheduled kernel for ``device``, export it and verify it in a worker process.\n",
    "\n",
    "    Prints a skip message and returns early when ``tvm.testing.device_enabled``\n",
    "    reports that ``device`` is not enabled.\n",
    "    \"\"\"\n",
    "    dev = tvm.device(device, 0)\n",
    "    if not tvm.testing.device_enabled(device):\n",
    "        print(\"Skip because %s is not enabled\" % device)\n",
    "        return\n",
    "    workdir = utils.tempdir()\n",
    "    lib = tvm.compile(sch.mod, target=device)\n",
    "\n",
    "    # Export through the cross-compiler hook so that code path is exercised too.\n",
    "    lib_path = workdir.relpath(\"dev_lib.so\")\n",
    "    lib.export_library(lib_path, fcompile=cc.cross_compiler(\"g++\"))\n",
    "\n",
    "    def popen_check():\n",
    "        import tvm\n",
    "\n",
    "        loaded = tvm.runtime.load_module(lib_path)\n",
    "        elem_dtype = A.dtype\n",
    "        src = tvm.nd.array(np.random.uniform(size=1024).astype(elem_dtype), dev)\n",
    "        dst = tvm.nd.array(np.zeros(1024, dtype=elem_dtype), dev)\n",
    "        loaded(src, dst)\n",
    "        np.testing.assert_equal(dst.numpy(), src.numpy() + 1)\n",
    "\n",
    "    # Hand the check to a PopenWorker and wait for its result.\n",
    "    checker = popen_pool.PopenWorker()\n",
    "    checker.send(popen_check)\n",
    "    checker.recv()\n",
    "\n",
    "def check_c(device):\n",
    "    \"\"\"Compile the scheduled kernel for ``device`` with a ``c`` host target and run it in-process.\n",
    "\n",
    "    Prints a skip message and returns early when ``tvm.testing.device_enabled``\n",
    "    reports that ``device`` is not enabled.\n",
    "    \"\"\"\n",
    "    dev = tvm.device(device, 0)\n",
    "    if not tvm.testing.device_enabled(device):\n",
    "        print(\"Skip because %s is not enabled\" % device)\n",
    "        return\n",
    "    c_host_target = tvm.target.Target(device, host=\"c\")\n",
    "    lib = tvm.compile(sch.mod, target=c_host_target)\n",
    "    elem_dtype = A.dtype\n",
    "    src = tvm.nd.array(np.random.uniform(size=1024).astype(elem_dtype), dev)\n",
    "    dst = tvm.nd.array(np.zeros(1024, dtype=elem_dtype), dev)\n",
    "    lib[\"main\"](src, dst)\n",
    "    np.testing.assert_equal(dst.numpy(), src.numpy() + 1)\n",
    "\n",
    "# Run both checks for every backend; backends that are not enabled print a skip line.\n",
    "for device in (\"cuda\", \"vulkan\", \"opencl\", \"metal\"):\n",
    "    check_device(device)\n",
    "    check_c(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 合并模块打包为一个动态库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running popen check\n"
     ]
    }
   ],
   "source": [
    "# graph: an element-wise \"add 1.0\" over a vector of length nn\n",
    "nn = 12\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name=\"B\")\n",
    "# Two IRModules over the same compute, differing only in the global symbol name.\n",
    "mod1, mod2 = (\n",
    "    tvm.IRModule.from_expr(te.create_prim_func([A, B]).with_attr(\"global_symbol\", symbol))\n",
    "    for symbol in (\"myadd1\", \"myadd2\")\n",
    ")\n",
    "\n",
    "def check_llvm():\n",
    "    \"\"\"Link two separately built object files into one shared library and call both functions.\n",
    "\n",
    "    ``mod1`` and ``mod2`` are built for ``llvm``, saved as ``myadd1.o`` and\n",
    "    ``myadd2.o``, linked into ``mylib.so``, and the functions ``myadd1`` and\n",
    "    ``myadd2`` are looked up in the loaded library and checked against\n",
    "    ``a + 1``. Prints a skip message and returns early when llvm is not enabled.\n",
    "    \"\"\"\n",
    "    dev = tvm.cpu(0)\n",
    "    if not tvm.testing.device_enabled(\"llvm\"):\n",
    "        print(\"Skip because llvm is not enabled\")\n",
    "        return\n",
    "    temp = utils.tempdir()\n",
    "    built1 = tvm.tir.build(mod1, \"llvm\")\n",
    "    built2 = tvm.tir.build(mod2, \"llvm\")\n",
    "    path1 = temp.relpath(\"myadd1.o\")\n",
    "    path2 = temp.relpath(\"myadd2.o\")\n",
    "    path_dso = temp.relpath(\"mylib.so\")\n",
    "    built1.save(path1)\n",
    "    built2.save(path2)\n",
    "    # create shared library with multiple functions\n",
    "    cc.create_shared(path_dso, [path1, path2])\n",
    "    m = tvm.runtime.load_module(path_dso)\n",
    "    a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), dev)\n",
    "    for symbol in (\"myadd1\", \"myadd2\"):\n",
    "        # Use a freshly zeroed output for each function: reusing the buffer would let\n",
    "        # the second assertion pass on the first call's result even if the second\n",
    "        # function wrote nothing.\n",
    "        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), dev)\n",
    "        m[symbol](a, b)\n",
    "        np.testing.assert_equal(b.numpy(), a.numpy() + 1)\n",
    "\n",
    "def check_system_lib():\n",
    "    \"\"\"Build both functions with an empty ``system_lib_prefix``, link them and verify them in a worker process.\n",
    "\n",
    "    The worker loads ``mylib.so`` through ``ctypes.CDLL``, fetches\n",
    "    ``tvm.runtime.system_lib()`` and checks ``myadd1`` and ``myadd2`` against\n",
    "    ``a + 1``. Prints a skip message and returns early when llvm is not enabled.\n",
    "    \"\"\"\n",
    "    dev = tvm.cpu(0)\n",
    "    if not tvm.testing.device_enabled(\"llvm\"):\n",
    "        print(\"Skip because llvm is not enabled\")\n",
    "        return\n",
    "    temp = utils.tempdir()\n",
    "    print(\"Running popen check\")\n",
    "    built1 = tvm.tir.build(mod1.with_attr(\"system_lib_prefix\", \"\"), \"llvm\")\n",
    "    built2 = tvm.tir.build(mod2.with_attr(\"system_lib_prefix\", \"\"), \"llvm\")\n",
    "    path1 = temp.relpath(\"myadd1.o\")\n",
    "    path2 = temp.relpath(\"myadd2.o\")\n",
    "    path_dso = temp.relpath(\"mylib.so\")\n",
    "    built1.save(path1)\n",
    "    built2.save(path2)\n",
    "    cc.create_shared(path_dso, [path1, path2])\n",
    "\n",
    "    def popen_check():\n",
    "        import tvm.runtime\n",
    "        import ctypes\n",
    "\n",
    "        # Load dll, will trigger system library registration\n",
    "        ctypes.CDLL(path_dso)\n",
    "        # Load the system wide library\n",
    "        mm = tvm.runtime.system_lib()\n",
    "        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), dev)\n",
    "        for symbol in (\"myadd1\", \"myadd2\"):\n",
    "            # Use a freshly zeroed output for each function: reusing the buffer would\n",
    "            # let the second assertion pass on the first call's result even if the\n",
    "            # second function wrote nothing.\n",
    "            b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), dev)\n",
    "            mm[symbol](a, b)\n",
    "            np.testing.assert_equal(b.numpy(), a.numpy() + 1)\n",
    "\n",
    "    # system lib should be loaded in different process\n",
    "    worker = popen_pool.PopenWorker()\n",
    "    worker.send(popen_check)\n",
    "    worker.recv()\n",
    "\n",
    "# The system-lib check is not run on Windows.\n",
    "if sys.platform != \"win32\":\n",
    "    check_system_lib()\n",
    "check_llvm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
